Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax/array-api): dipole/polarizability fitting #4278

Merged
merged 1 commit into from
Oct 31, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Oct 30, 2024

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced DipoleFittingNet and PolarFittingNet classes for enhanced fitting functionality.
    • Expanded support for JAX as a backend for fitting tensors, alongside existing TensorFlow and PyTorch support.
  • Bug Fixes

    • Improved error handling and parameter validation in the DipoleFitting and PolarFitting classes.
  • Documentation

    • Updated documentation to reflect JAX as a supported backend for fitting tensors.
  • Tests

    • Enhanced testing framework to support evaluations with JAX and Array API Strict, including new test methods and properties.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Copy link
Contributor

coderabbitai bot commented Oct 30, 2024

📝 Walkthrough

Walkthrough

The pull request introduces significant modifications to the DipoleFitting and PolarFitting classes, enhancing their compatibility with various array backends through the inclusion of array_api_compat. Key changes include updates to the call methods to utilize backend-agnostic operations, along with the addition of new fitting network classes (DipoleFittingNet and PolarFittingNet). Furthermore, the documentation is updated to reflect JAX as a supported backend. Testing frameworks are also enhanced to support evaluations across different backends, ensuring comprehensive testing capabilities.

Changes

File Path Change Summary
deepmd/dpmodel/fitting/dipole_fitting.py Modified DipoleFitting class; added array_api_compat, updated call method for backend compatibility, reshaped output using xp.reshape, and changed computation to out @ gr.
deepmd/dpmodel/fitting/polarizability_fitting.py Modified PolarFitting class; added array_api_compat and to_numpy_array, updated constructor for scale, modified serialize and call methods for array API compatibility.
deepmd/jax/fitting/__init__.py Added DipoleFittingNet and PolarFittingNet to the module's public interface.
deepmd/jax/fitting/fitting.py Introduced DipoleFittingNet and PolarFittingNet classes, registered with BaseFitting, and implemented __setattr__ methods for attribute management.
doc/model/train-fitting-tensor.md Updated to include JAX as a supported backend for fitting tensors, with formatting adjustments for clarity.
source/tests/array_api_strict/fitting/fitting.py Added DipoleFittingNet and PolarFittingNet classes, enhanced existing classes with setattr_for_general_fitting for attribute handling.
source/tests/consistent/fitting/test_dipole.py Enhanced testing for dipole fitting; added support for JAX and Array API Strict, with new evaluation methods and properties for conditional test execution.
source/tests/consistent/fitting/test_polar.py Enhanced testing for polar fitting; similarly added support for JAX and Array API Strict, with new evaluation methods and properties for conditional test execution.

Possibly related PRs

Suggested labels

Python, Examples, Docs

Suggested reviewers

  • wanghan-iapcm
  • iProzd

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (6)
deepmd/jax/fitting/fitting.py (1)

70-82: LGTM with a suggestion: Consider extracting polar-specific attribute handling

The implementation is correct and properly handles polar-specific attributes. However, to improve maintainability, consider extracting the polar-specific attribute handling into a helper function.

Consider refactoring like this:

+def setattr_for_polar_fitting(name: str, value: Any) -> Any:
+    if name in {"scale", "constant_matrix"}:
+        value = to_jax_array(value)
+        if value is not None:
+            value = ArrayAPIVariable(value)
+    return value

 @BaseFitting.register("polar")
 @flax_module
 class PolarFittingNet(PolarFittingNetDP):
     def __setattr__(self, name: str, value: Any) -> None:
         value = setattr_for_general_fitting(name, value)
-        if name in {
-            "scale",
-            "constant_matrix",
-        }:
-            value = to_jax_array(value)
-            if value is not None:
-                value = ArrayAPIVariable(value)
+        value = setattr_for_polar_fitting(name, value)
         return super().__setattr__(name, value)

This change would:

  1. Make the code more maintainable by isolating polar-specific logic
  2. Make it easier to test the attribute handling separately
  3. Follow the same pattern as setattr_for_general_fitting
source/tests/consistent/fitting/test_dipole.py (2)

167-175: Consider adding type hints for return value.

The JAX evaluation implementation is clean and consistent with other backends.

Consider adding a return type hint for better type safety:

-    def eval_jax(self, jax_obj: Any) -> Any:
+    def eval_jax(self, jax_obj: Any) -> np.ndarray:

177-185: Consider adding type hints for return value.

The Array API Strict evaluation implementation is clean and consistent with other backends.

Consider adding a return type hint for better type safety:

-    def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
+    def eval_array_api_strict(self, array_api_strict_obj: Any) -> np.ndarray:
doc/model/train-fitting-tensor.md (1)

Line range hint 1-999: Document needs JAX-specific sections.

The document should be updated to include JAX-specific information in several sections:

  1. Add a JAX tab in the examples section showing paths to JAX input files
  2. Add JAX-specific configuration examples in the "The fitting Network" section
  3. Add JAX command in the "Train the Model" section

Example addition for the "Train the Model" section:

 :::::{tab-set}
 
 :::{tab-item} TensorFlow {{ tensorflow_icon }}
 
 ```bash
 dp train input.json

:::

:::{tab-item} PyTorch {{ pytorch_icon }}

dp --pt train input.json

:::

+:::{tab-item} JAX {{ jax_icon }}
+
+bash +dp --jax train input.json +
+
+:::
+
::::


</details>
<details>
<summary>source/tests/consistent/fitting/test_polar.py (2)</summary>

`89-90`: **Consider shortening variable name `array_api_strict_class` for readability.**

The variable `array_api_strict_class` is descriptive but somewhat long. For improved readability, consider renaming it to align with naming conventions.



For example:

```diff
 jax_class = PolarFittingJAX
-array_api_strict_class = PolarFittingArrayAPIStrict
+api_strict_class = PolarFittingArrayAPIStrict

This makes the variable name shorter while still conveying its purpose.


92-93: Add docstrings for new properties skip_jax and skip_array_api_strict.

The properties skip_jax and skip_array_api_strict are introduced without docstrings. Adding brief docstrings would enhance code readability and maintainability.

For example:

 skip_jax = not INSTALLED_JAX
+"""Indicates whether to skip JAX tests based on installation."""
 skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
+"""Indicates whether to skip Array API Strict tests based on installation."""
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between d165fee and 2b4898f.

📒 Files selected for processing (8)
  • deepmd/dpmodel/fitting/dipole_fitting.py (2 hunks)
  • deepmd/dpmodel/fitting/polarizability_fitting.py (6 hunks)
  • deepmd/jax/fitting/init.py (1 hunks)
  • deepmd/jax/fitting/fitting.py (2 hunks)
  • doc/model/train-fitting-tensor.md (1 hunks)
  • source/tests/array_api_strict/fitting/fitting.py (2 hunks)
  • source/tests/consistent/fitting/test_dipole.py (4 hunks)
  • source/tests/consistent/fitting/test_polar.py (4 hunks)
🔇 Additional comments (18)
deepmd/jax/fitting/__init__.py (1)

3-6: LGTM! Clean and well-structured changes.

The additions follow the module's existing patterns and maintain proper organization.

Let's verify the existence of the newly imported classes:

Also applies to: 12-13

✅ Verification successful

Classes are properly implemented and exposed

The verification confirms that both DipoleFittingNet and PolarFittingNet classes are:

  • Properly defined in deepmd/jax/fitting/fitting.py
  • Correctly registered with @BaseFitting.register decorators
  • Inherit from their respective DP model counterparts
  • Follow the same pattern as existing fitting networks
  • Properly exposed in the module's __init__.py
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the new fitting network classes exist in the source module

# Test: Check if the new classes are defined in the fitting module
ast-grep --pattern 'class DipoleFittingNet' 
ast-grep --pattern 'class PolarFittingNet'

Length of output: 90


Script:

#!/bin/bash
# Let's try a broader search to locate these classes in the codebase

# Search for files containing DipoleFittingNet
echo "=== Files containing DipoleFittingNet ==="
rg "DipoleFittingNet" -l

# Search for files containing PolarFittingNet
echo -e "\n=== Files containing PolarFittingNet ==="
rg "PolarFittingNet" -l

# Get content of any Python files in the fitting directory
echo -e "\n=== Content of Python files in fitting directory ==="
fd "\.py$" "deepmd/jax/fitting" --exec cat {}

Length of output: 4356

source/tests/array_api_strict/fitting/fitting.py (2)

6-11: LGTM: Import statements are well-organized

The new imports for dipole and polarizability fitting components follow the established pattern and are correctly placed.


52-55: LGTM: DipoleFittingNet implementation is consistent

The implementation follows the established pattern of using setattr_for_general_fitting for array backend compatibility, maintaining consistency with other fitting networks.

deepmd/jax/fitting/fitting.py (2)

6-6: LGTM: Clean import additions

The new imports for dipole and polar fitting base classes are well-organized and properly aliased.

Also applies to: 9-11


62-67: LGTM: Well-structured DipoleFittingNet implementation

The implementation follows the established pattern, properly integrating with the fitting framework through registration and maintaining consistency with other fitting networks.

source/tests/consistent/fitting/test_dipole.py (2)

15-16: LGTM! Well-structured backend integration.

The new imports and class properties for JAX and Array API Strict backends are cleanly integrated, following the existing patterns for conditional imports and skip flags.

Also applies to: 37-51, 89-93


92-93: Verify test coverage across all backends.

Let's ensure all test cases are being executed for each backend.

✅ Verification successful

Let me gather more information about the test coverage and backend handling.


Let me check the actual test implementations to understand how different backends are being tested.


Based on the comprehensive test file content, I can now provide the final response:

Test coverage across backends is properly implemented

The test implementation shows a robust framework for testing all backends:

  • Tests are parameterized with different configurations (resnet_dt, precision, mixed_types)
  • Each backend (TF, PT, JAX, Array API Strict) has dedicated evaluation methods
  • Skip conditions are properly handled for optional backends (JAX, Array API Strict)
  • Common test infrastructure ensures consistent testing across all available backends
  • Appropriate tolerance levels are set based on precision
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check test coverage for dipole fitting across backends

# Test: Look for test execution patterns across different backends
echo "Checking test patterns across backends..."
rg -A 5 "def test_" source/tests/consistent/fitting/test_dipole.py

# Test: Verify skip conditions are properly handled
echo "Checking skip conditions..."
rg "skip_.*=.*not.*INSTALLED" source/tests/consistent/fitting/test_dipole.py

Length of output: 397


Script:

#!/bin/bash
# Check for test decorators and backend-specific test implementations
echo "Checking test decorators and backend handling..."
rg -B 2 -A 10 "@pytest" source/tests/consistent/fitting/test_dipole.py

# Look for backend-specific imports and configurations
echo "Checking backend-related imports and configurations..."
rg "INSTALLED_|import|backend" source/tests/consistent/fitting/test_dipole.py

# Check for any parametrized tests that might handle different backends
echo "Checking for parametrized tests..."
rg -B 2 -A 5 "@pytest.mark.parametrize" source/tests/consistent/fitting/test_dipole.py

Length of output: 1584


Script:

#!/bin/bash
# Get the full content of the test file to understand the test structure
echo "Checking full test implementation..."
cat source/tests/consistent/fitting/test_dipole.py

# Specifically look for test methods and their skip decorators
echo "Checking test methods and skip decorators..."
rg -B 1 -A 3 "(\s+def test_|@pytest.mark.skip)" source/tests/consistent/fitting/test_dipole.py

Length of output: 6222

doc/model/train-fitting-tensor.md (1)

1-1: LGTM: JAX backend added to title.

The addition of the JAX icon to the title correctly reflects the new backend support.

source/tests/consistent/fitting/test_polar.py (3)

167-176: Ensure proper conversion of inputs in eval_jax method.

The eval_jax method correctly converts inputs to JAX arrays using jnp.asarray. This ensures compatibility with JAX operations.


177-186: Ensure proper conversion of inputs in eval_array_api_strict method.

The eval_array_api_strict method uses array_api_strict.asarray to convert inputs, which is appropriate for Array API Strict compatibility.


47-49: ⚠️ Potential issue

Verify relative import path for PolarFittingArrayAPIStrict.

Please check if the relative import path is correct in the following import statement:

from ...array_api_strict.fitting.fitting import (
    PolarFittingNet as PolarFittingArrayAPIStrict,
)

Using '...' in relative imports moves up three levels in the package hierarchy. Ensure that this path correctly locates PolarFittingNet.

To confirm the validity of the import path, you can run the following script:

✅ Verification successful

The relative import path is correct

The import path from ...array_api_strict.fitting.fitting import PolarFittingNet as PolarFittingArrayAPIStrict is verified to be correct because:

  1. The importing file is at source/tests/consistent/fitting/test_polar.py
  2. Going up three levels (...) from this location leads to source/tests/
  3. The target file exists at source/tests/array_api_strict/fitting/fitting.py
  4. The PolarFittingNet class is defined in the target file
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the existence of 'PolarFittingNet' in the specified path.

# Expected: The path leads to a file defining 'PolarFittingNet'.

# Search for 'PolarFittingNet' class definition in 'fitting.py' files.
fd --type f --name 'fitting.py' --exec grep -H 'class PolarFittingNet' {}

Length of output: 326


Script:

#!/bin/bash
# Let's try a different approach to locate the file and class definition

# First, find the current file location
fd "test_polar.py$"

# Then search for fitting.py files
fd "fitting.py$"

# Finally, let's use ast-grep to find the class definition
ast-grep --pattern 'class PolarFittingNet {
  $$$
}'

Length of output: 1168


Script:

#!/bin/bash
# Let's examine the content of the relevant files to trace the import path

# First, check the content of test_polar.py to understand the context
cat source/tests/consistent/fitting/test_polar.py

# Then check the target file where PolarFittingNet should be defined
cat source/tests/array_api_strict/fitting/fitting.py

Length of output: 8046

deepmd/dpmodel/fitting/dipole_fitting.py (4)

9-9: Addition of array_api_compat import for backend compatibility

The import of array_api_compat enhances the code's compatibility with different array backends, facilitating backend-agnostic operations.


211-211: Ensure consistency of input arrays when obtaining array namespace

When using xp = array_api_compat.array_namespace(descriptor, atype), please verify that descriptor and atype are from the same array backend to prevent any potential inconsistencies during computations.


219-221: Use of xp.reshape for backend-agnostic reshaping

Replacing np.reshape with xp.reshape ensures that reshaping operations are compatible with the array backend in use.


223-225: Confirm equivalence of matrix multiplication to original np.einsum operation

The replacement of the commented out np.einsum with out = out @ gr followed by reshaping simplifies the code and leverages the matrix multiplication operator. Please confirm that out and gr have compatible shapes and that this operation yields the same results as the original einsum expression.

deepmd/dpmodel/fitting/polarizability_fitting.py (3)

9-9: Import array_api_compat for array backend compatibility

The addition of import array_api_compat ensures compatibility with various array backends, which is appropriate for enhancing flexibility.


18-20: Import to_numpy_array for consistent serialization

Importing to_numpy_array from deepmd.dpmodel.common facilitates consistent serialization of arrays, which is good practice.


131-142: ⚠️ Potential issue

Fix missing handling in elif isinstance(scale, float):

There is no code under the elif isinstance(scale, float): condition. This will result in a SyntaxError or unintended behavior. You should handle this case by converting the float scale into a list with length equal to ntypes.

Apply this diff to fix the missing handling:

     elif isinstance(scale, float):
+        scale = [scale for _ in range(ntypes)]
     else:
         raise ValueError(
             "Scale must be a list of float of length ntypes or a float."
         )

Likely invalid or redundant comment.

deepmd/jax/fitting/fitting.py Show resolved Hide resolved
doc/model/train-fitting-tensor.md Show resolved Hide resolved
source/tests/consistent/fitting/test_polar.py Show resolved Hide resolved
Copy link

codecov bot commented Oct 30, 2024

Codecov Report

Attention: Patch coverage is 98.03922% with 1 line in your changes missing coverage. Please review.

Project coverage is 84.30%. Comparing base (159361d) to head (2b4898f).
Report is 7 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/dpmodel/fitting/polarizability_fitting.py 96.29% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4278      +/-   ##
==========================================
- Coverage   84.37%   84.30%   -0.07%     
==========================================
  Files         551      553       +2     
  Lines       51585    51844     +259     
  Branches     3052     3052              
==========================================
+ Hits        43524    43707     +183     
- Misses       7100     7177      +77     
+ Partials      961      960       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@anyangml anyangml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants